{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook computes LRP (Layer-wise Relevance Propagation), SA (Sensitivity Analysis), and GI (Gradient x Input) relevances for an example test sentence and a chosen relevance *target* class, using a bidirectional LSTM trained on the Stanford Sentiment Treebank (SST) dataset.\n",
    "\n",
    "The LRP implementation is based on the following papers:\n",
    "- [https://doi.org/10.1371/journal.pone.0130140](https://doi.org/10.1371/journal.pone.0130140)\n",
    "- [https://doi.org/10.18653/v1/W17-5221](https://doi.org/10.18653/v1/W17-5221)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from code.LSTM.LSTM_bidi import LSTM_bidi\n",
    "from code.util.heatmap import html_heatmap\n",
    "\n",
    "import codecs\n",
    "import numpy as np\n",
    "from IPython.display import display, HTML"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define input sequence and relevance target class"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sentiment classes are encoded as follows:  \n",
    "**0=very negative, 1=negative, 2=neutral, 3=positive, 4=very positive**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_test_sentence(sent_idx):\n",
    "    \"\"\"Return an SST test set sentence and its true label; sent_idx must be an integer in [1, 2210].\"\"\"\n",
    "    with codecs.open(\"./data/sequence_test.txt\", 'r', encoding='utf8') as f:\n",
    "        for idx, line in enumerate(f, start=1):                         # lines are numbered from 1\n",
    "            if idx == sent_idx:\n",
    "                fields     = line.rstrip('\\n').split('\\t')\n",
    "                true_class = int(fields[0]) - 1                         # true class, shifted to the 0-4 encoding\n",
    "                words      = fields[1].split(' | ')                     # sentence as list of words\n",
    "                return words, true_class\n",
    "    raise ValueError(\"sent_idx must be an integer in [1, 2210], got {}\".format(sent_idx))\n",
    "\n",
    "def predict(words):\n",
    "    \"\"\"Returns the classifier's predicted class\"\"\"\n",
    "    net                 = LSTM_bidi()                                   # load trained LSTM model\n",
    "    w_indices           = [net.voc.index(w) for w in words]             # convert input sentence to word IDs\n",
    "    net.set_input(w_indices)                                            # set LSTM input sequence\n",
    "    scores              = net.forward()                                 # classification prediction scores\n",
    "    return np.argmax(scores)                                            # index of the highest score"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an input sequence, either select a sentence from the Stanford Sentiment Treebank (SST) test set, or define your own sequence."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "words, _ = get_test_sentence(291)                                       # SST test set sentence number 291"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternatively, uncomment one of the following sentences, or define your own sequence (only words contained in the vocabulary are supported!)\n",
    "#words = ['this','movie','was','actually','neither','that','funny',',','nor','super','witty','.']\n",
    "#words = ['this', 'film', 'does', 'n\\'t', 'care', 'about', 'cleverness', ',', 'wit', 'or', 'any', 'other', 'kind', 'of', 'intelligent', 'humor', '.']\n",
    "#words = ['i','hate','the','movie','though','the','plot','is','interesting','.']\n",
    "#words = ['used', 'to', 'be', 'my', 'favorite']\n",
    "#words = ['not', 'worth', 'the', 'time']\n",
    "#words = ['is', 'n\\'t', 'a', 'bad', 'film'] # Note: misclassified sample!\n",
    "#words = ['is', 'n\\'t', 'very', 'interesting'] \n",
    "#words = ['it', '\\'s', 'easy' ,'to' ,'love' ,'robin' ,'tunney' ,'--' ,'she' ,'\\'s' ,'pretty' ,'and' ,'she' ,'can' ,'act' ,'--' ,'but' ,'it' ,'gets' ,'harder' ,'and' ,'harder' ,'to' ,'understand' ,'her' ,'choices', '.']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To understand why a single sample is classified or misclassified, we strongly **recommend using the classifier's *predicted* class as the relevance *target* class**: it is the class the model is most confident about, so the resulting relevances reflect the classifier's \"point of view\" on the test sample more accurately.\n",
    "(More generally, any class can be chosen as the relevance *target* class.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "predicted_class = predict(words)                                        # get predicted class\n",
    "target_class    = predicted_class                                       # define relevance target class "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['neither', 'funny', 'nor', 'suspenseful', 'nor', 'particularly', 'well-drawn', '.']\n",
      "\n",
      "predicted class:           0\n"
     ]
    }
   ],
   "source": [
    "print (words)\n",
    "print (\"\\npredicted class:          \",   predicted_class)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Compute LRP relevances"
   ]
  },
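  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the next cell, `net.lrp` returns one relevance array per LSTM direction (`Rx` and `Rx_rev`). Assuming each array has one row per word and one column per embedding dimension, which is what the sum over `axis=1` suggests, the word-level relevance of word $t$ is:\n",
    "\n",
    "$$R_t = \\sum_{d} \\left( (R_x)_{t,d} + (R_x^{\\mathrm{rev}})_{t,d} \\right)$$\n",
    "\n",
    "A positive $R_t$ speaks for the target class and a negative one against it."
   ]
  },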
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LRP hyperparameters:\n",
    "eps                 = 0.001                                             # small positive number\n",
    "bias_factor         = 0.0                                               # recommended value\n",
    " \n",
    "net                 = LSTM_bidi()                                       # load trained LSTM model\n",
    "\n",
    "w_indices           = [net.voc.index(w) for w in words]                 # convert input sentence to word IDs\n",
    "Rx, Rx_rev, R_rest  = net.lrp(w_indices, target_class, eps, bias_factor) # perform LRP\n",
    "R_words             = np.sum(Rx + Rx_rev, axis=1)                       # compute word-level LRP relevances\n",
    "\n",
    "scores              = net.s.copy()                                      # classification prediction scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "prediction scores:         [ 2.73149687  2.7249559   0.80547211 -1.5359282  -4.6083298 ]\n",
      "\n",
      "LRP target class:          0\n",
      "\n",
      "LRP relevances:\n",
      "\t\t\t    1.86\tneither\n",
      "\t\t\t   -1.58\tfunny\n",
      "\t\t\t    1.50\tnor\n",
      "\t\t\t   -1.54\tsuspenseful\n",
      "\t\t\t    2.00\tnor\n",
      "\t\t\t   -0.04\tparticularly\n",
      "\t\t\t   -0.06\twell-drawn\n",
      "\t\t\t   -0.12\t.\n",
      "\n",
      "LRP heatmap:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<span style=\"background-color:#ff1111\">neither</span> <span style=\"background-color:#3434ff\">funny</span> <span style=\"background-color:#ff4040\">nor</span> <span style=\"background-color:#3a3aff\">suspenseful</span> <span style=\"background-color:#ff0000\">nor</span> <span style=\"background-color:#fafaff\">particularly</span> <span style=\"background-color:#f8f8ff\">well-drawn</span> <span style=\"background-color:#f0f0ff\">.</span> \n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "print (\"prediction scores:        \",   scores)\n",
    "print (\"\\nLRP target class:         \", target_class)\n",
    "print (\"\\nLRP relevances:\")\n",
    "for idx, w in enumerate(words):\n",
    "    print (\"\\t\\t\\t\" + \"{:8.2f}\".format(R_words[idx]) + \"\\t\" + w)\n",
    "print (\"\\nLRP heatmap:\")    \n",
    "display(HTML(html_heatmap(words, R_words)))"
   ]
  },
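  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following check relies on relevance conservation. As a sketch in the notation of the code (the symbols are only illustrative), with `bias_factor = 1.0` the relevances returned by `net.lrp` should sum, up to numerical tolerance, to the prediction score $s_c$ of the target class $c$:\n",
    "\n",
    "$$\\sum_{t,d} (R_x)_{t,d} + \\sum_{t,d} (R_x^{\\mathrm{rev}})_{t,d} + \\sum R_{\\mathrm{rest}} \\approx s_c$$\n",
    "\n",
    "Here $s_c$ stands for `net.s[target_class]` and $R_{\\mathrm{rest}}$ for the `R_rest` array returned by `net.lrp`."
   ]
  },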
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.7314968677959888\n",
      "Sanity check passed?  True\n"
     ]
    }
   ],
   "source": [
    "# How to sanity check global relevance conservation:\n",
    "bias_factor        = 1.0                                             # value to use for sanity check\n",
    "Rx, Rx_rev, R_rest = net.lrp(w_indices, target_class, eps, bias_factor)\n",
    "R_tot              = Rx.sum() + Rx_rev.sum() + R_rest.sum()          # sum of all \"input\" relevances\n",
    "\n",
    "print(R_tot)\n",
    "print(\"Sanity check passed? \", np.allclose(R_tot, net.s[target_class]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Compute SA/GI relevances"
   ]
  },
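  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough guide to the next cell, and assuming `Gx` and `Gx_rev` hold the gradients of the target class score with respect to the word embeddings fed to the two LSTM directions (an interpretation of the `net.backward` outputs, not something stated in this notebook), the relevances of word $t$ can be written as follows, where $g_t$ is row $t$ of `Gx + Gx_rev` and $x_t$ is row $t$ of `net.x`:\n",
    "\n",
    "$$R_t^{\\mathrm{SA}} = \\lVert g_t \\rVert_2^2 = \\sum_{d} g_{t,d}^2, \\qquad R_t^{\\mathrm{GI}} = \\sum_{d} g_{t,d}\\, x_{t,d}$$\n",
    "\n",
    "SA relevances are therefore always non-negative, whereas GI relevances are signed."
   ]
  },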
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "net              = LSTM_bidi()                                       # load trained LSTM model\n",
    "\n",
    "w_indices        = [net.voc.index(w) for w in words]                 # convert input sentence to word IDs\n",
    "Gx, Gx_rev       = net.backward(w_indices, target_class)             # perform gradient backpropagation\n",
    "R_words_SA       = (np.linalg.norm(Gx + Gx_rev, ord=2, axis=1))**2   # compute word-level Sensitivity Analysis relevances\n",
    "R_words_GI       = ((Gx + Gx_rev)*net.x).sum(axis=1)                 # compute word-level GradientxInput relevances\n",
    "\n",
    "scores           = net.s.copy()                                      # classification prediction scores "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "prediction scores:        [ 2.73149687  2.7249559   0.80547211 -1.5359282  -4.6083298 ]\n",
      "\n",
      "SA/GI target class:       0\n",
      "\n",
      "SA relevances:\n",
      "\t\t\t    5.01\tneither\n",
      "\t\t\t    0.35\tfunny\n",
      "\t\t\t    0.73\tnor\n",
      "\t\t\t    0.92\tsuspenseful\n",
      "\t\t\t    1.66\tnor\n",
      "\t\t\t    0.13\tparticularly\n",
      "\t\t\t    0.66\twell-drawn\n",
      "\t\t\t    0.32\t.\n",
      "\n",
      "SA heatmap:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<span style=\"background-color:#ff0000\">neither</span> <span style=\"background-color:#ffeeee\">funny</span> <span style=\"background-color:#ffdada\">nor</span> <span style=\"background-color:#ffd0d0\">suspenseful</span> <span style=\"background-color:#ffaaaa\">nor</span> <span style=\"background-color:#fff8f8\">particularly</span> <span style=\"background-color:#ffdede\">well-drawn</span> <span style=\"background-color:#ffeeee\">.</span> \n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "GI relevances:\n",
      "\t\t\t    0.03\tneither\n",
      "\t\t\t    0.06\tfunny\n",
      "\t\t\t   -0.11\tnor\n",
      "\t\t\t   -0.19\tsuspenseful\n",
      "\t\t\t   -0.19\tnor\n",
      "\t\t\t   -0.07\tparticularly\n",
      "\t\t\t   -0.06\twell-drawn\n",
      "\t\t\t    0.03\t.\n",
      "\n",
      "GI heatmap:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<span style=\"background-color:#ffdede\">neither</span> <span style=\"background-color:#ffb3b3\">funny</span> <span style=\"background-color:#6868ff\">nor</span> <span style=\"background-color:#0000ff\">suspenseful</span> <span style=\"background-color:#0000ff\">nor</span> <span style=\"background-color:#a2a2ff\">particularly</span> <span style=\"background-color:#acacff\">well-drawn</span> <span style=\"background-color:#ffdcdc\">.</span> \n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "print (\"prediction scores:       \",   scores)\n",
    "print (\"\\nSA/GI target class:      \", target_class)\n",
    "print (\"\\nSA relevances:\")\n",
    "for idx, w in enumerate(words):\n",
    "    print (\"\\t\\t\\t\" + \"{:8.2f}\".format(R_words_SA[idx]) + \"\\t\" + w)\n",
    "print (\"\\nSA heatmap:\")    \n",
    "display(HTML(html_heatmap(words, R_words_SA)))\n",
    "print (\"\\nGI relevances:\")\n",
    "for idx, w in enumerate(words):\n",
    "    print (\"\\t\\t\\t\" + \"{:8.2f}\".format(R_words_GI[idx]) + \"\\t\" + w)\n",
    "print (\"\\nGI heatmap:\")    \n",
    "display(HTML(html_heatmap(words, R_words_GI)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
